from tqdm import tqdm
import subprocess
import json
import shutil
import re
import os


def remove_think_process(text):
    """Strip every ``<think ...>...</think>`` reasoning section from *text*.

    Matching is case-insensitive and spans newlines, and whitespace touching a
    removed section is dropped with it. Afterwards, runs of three or more
    newlines are squeezed to a single blank line and the result is stripped.
    """
    think_section = r"\s*<think\b[^>]*>.*?</think>\s*"
    without_thoughts = re.sub(
        think_section, "", text, flags=re.IGNORECASE | re.DOTALL
    )
    # Removing sections can leave stacked blank lines behind; keep at most one.
    squeezed = re.sub(r"\n{3,}", "\n\n", without_thoughts)
    return squeezed.strip()

def extract_python_code_block(text: str):
    """Return the body of the Python fenced code block in *text*, if any.

    Fenced blocks (three backticks, optional info string, newline, body,
    three backticks) are scanned in order. The first block tagged
    ``python`` / ``py`` / ``python3`` (case-insensitive) wins; if none is
    tagged as Python, the first fenced block of any language is used.

    Returns:
        ``(code, True)`` with the chosen block body stripped of surrounding
        whitespace, or ``(text.strip(), False)`` when *text* has no fenced
        block at all.
    """
    # Info strings may contain "+"/"-" (e.g. "c++"), and editors/models may
    # leave trailing spaces or a "\r" before the newline ("```python \r\n").
    pattern = re.compile(r"```([\w+-]*)[^\S\n]*\n(.*?)```", re.DOTALL)
    blocks = [(m.group(1).lower(), m.group(2)) for m in pattern.finditer(text)]
    if not blocks:
        return text.strip(), False

    # LLM answers often show a shell snippet (e.g. "pip install manim")
    # before the actual Python code; prefer an explicitly Python-tagged block.
    for lang, body in blocks:
        if lang in {"python", "py", "python3"}:
            return body.strip(), True
    return blocks[0][1].strip(), True


def validate_manim_code(code_string: str, work_dir: str, timeout: float = 120):
    """Check that *code_string* renders with Manim.

    The first Scene subclass found in the code is rendered. The code is
    written to ``<work_dir>/temp_scene.py`` and rendered with ``manim -qm``
    (medium quality) with *work_dir* as the working directory.

    Args:
        code_string: Python source expected to define a Manim scene class.
        work_dir: Existing directory for the temp file and the render run.
        timeout: Seconds to wait for Manim before giving up (default 120).

    Returns:
        ``(True, None)`` when Manim exits with status 0. Otherwise
        ``(False, message)``, where *message* describes the failure. For a
        non-zero exit it includes Manim's captured stdout/stderr. Failure
        messages from the render stage are also printed.
    """
    # Accept any single base class whose name ends in "Scene" (Scene,
    # ThreeDScene, MovingCameraScene, manim.Scene, ...), tolerating spacing
    # like "class Foo( Scene ) :". Anchored to line start so commented-out
    # examples are not picked up.
    scene_match = re.search(
        r"^[ \t]*class\s+(\w+)\s*\(\s*(?:\w+\.)*\w*Scene\s*\)\s*:",
        code_string,
        re.MULTILINE,
    )
    if not scene_match:
        return False, "验证错误: 未找到 'class YourScene(Scene):' 定义。"
    scene_name = scene_match.group(1)

    temp_code_path = os.path.join(work_dir, "temp_scene.py")
    try:
        with open(temp_code_path, "w", encoding="utf-8") as f:
            f.write(code_string)
    except OSError as e:
        return False, f"验证错误: 无法写入临时文件: {e}"

    # Absolute path: the subprocess runs with cwd=work_dir, so a relative
    # work_dir-joined path would otherwise be resolved twice.
    command = ["manim", "-qm", os.path.abspath(temp_code_path), scene_name]
    try:
        result = subprocess.run(
            command,
            cwd=work_dir,
            capture_output=True,
            text=True,
            encoding="utf-8",
            # Tool output is not guaranteed to be valid UTF-8; without this a
            # decode error would be reported as a failed render.
            errors="replace",
            timeout=timeout,
        )
    except FileNotFoundError:
        error_message = "验证错误: 'manim' 命令未找到。Manim是否已安装并添加到系统路径(PATH)?"
    except subprocess.TimeoutExpired:
        error_message = f"验证错误: Manim渲染超时 (>{timeout}秒)。"
    except Exception as e:  # batch boundary: report the sample and keep going
        error_message = f"发生未知的验证错误: {e}"
    else:
        if result.returncode == 0:
            return True, None
        error_message = (
            f"Manim渲染失败，退出码 {result.returncode}。\n"
            f"--- STDOUT ---\n{result.stdout}\n"
            f"--- STDERR ---\n{result.stderr}"
        )

    print(error_message)
    return False, error_message


# Lower-cased file extensions that collect_outputs() treats as rendered videos.
VIDEO_EXTS = {
    ".mp4",
    ".mov",
    ".mkv",
    ".webm",
}
# Lower-cased file extensions that collect_outputs() treats as still images.
IMAGE_EXTS = {
    ".png",
    ".jpg",
    ".jpeg",
    ".gif",
}

def pick_largest(paths):
    """Return the path with the biggest file size, breaking ties by newest mtime.

    Heuristic: the largest file is usually the final render. On a full tie
    the earliest such path in *paths* is returned. Raises ValueError if
    *paths* is empty.
    """
    def _size_then_mtime(path):
        info = os.stat(path)
        return info.st_size, info.st_mtime

    return max(paths, key=_size_then_mtime)

def _scan_media(subdir):
    """Walk *subdir* and bucket media files into ``(finals, partials, images)``.

    Videos anywhere under a ``partial_movie_files`` directory count as
    partials; all other videos are finals. ``__pycache__`` trees are skipped.
    """
    finals, partials, images = [], [], []
    for dirpath, dirnames, filenames in os.walk(subdir):
        # Prune in place so os.walk never descends into __pycache__.
        dirnames[:] = [d for d in dirnames if d != "__pycache__"]
        # Compare whole path components relative to subdir, so an ancestor
        # path that merely contains the text cannot misclassify files.
        rel_parts = os.path.relpath(dirpath, subdir).split(os.sep)
        in_partial_dir = "partial_movie_files" in rel_parts
        for fn in filenames:
            ext = os.path.splitext(fn)[1].lower()
            fpath = os.path.join(dirpath, fn)
            if ext in VIDEO_EXTS:
                (partials if in_partial_dir else finals).append(fpath)
            elif ext in IMAGE_EXTS:
                images.append(fpath)
    return finals, partials, images


def _unique_destination(output_dir, stem, ext):
    """Return ``output_dir/<stem><ext>``, or the first free ``<stem>_<N><ext>`` if taken."""
    dst = os.path.join(output_dir, f"{stem}{ext}")
    i = 1
    while os.path.exists(dst):
        dst = os.path.join(output_dir, f"{stem}_{i}{ext}")
        i += 1
    return dst


def collect_outputs(root_dir: str, output_dir: str):
    """Copy one representative render per numbered sample dir into *output_dir*.

    Sub-directories of *root_dir* whose names are decimal numbers are
    processed in numeric order; other entries are ignored. For each one, the
    pick in priority order is:

    - the largest final video;
    - else the largest partial-movie video (tagged ``_partial``);
    - else the largest image.

    The pick is copied (with metadata) to ``output_dir/<index><tag><ext>``.
    Existing files are never overwritten; ``_1``, ``_2``, ... are appended.
    """
    os.makedirs(output_dir, exist_ok=True)

    # isdecimal(), not isdigit(): isdigit() accepts characters such as "²"
    # that int() rejects, which would crash the numeric sort.
    indices = sorted(
        (
            name
            for name in os.listdir(root_dir)
            if name.isdecimal() and os.path.isdir(os.path.join(root_dir, name))
        ),
        key=int,
    )

    for name in indices:
        finals, partials, images = _scan_media(os.path.join(root_dir, name))

        if finals:
            candidate, tag = pick_largest(finals), ""
        elif partials:
            candidate, tag = pick_largest(partials), "_partial"
        elif images:
            candidate, tag = pick_largest(images), ""
        else:
            print(f"[{name}] No video/image found. Skipping.")
            continue

        ext = os.path.splitext(candidate)[1].lower()
        dst = _unique_destination(output_dir, f"{name}{tag}", ext)
        shutil.copy2(candidate, dst)
        kind = "video" if ext in VIDEO_EXTS else "image"
        print(f"[{name}] Copied {kind}{' (partial)' if tag else ''}: {candidate} -> {dst}")

def main():
    """Render every generated sample, report the success count, then collect outputs.

    Reads ``/gen_code_manim_<model>.jsonl``, one JSON object per line with a
    ``code`` field. Each sample is rendered into
    ``/outputs/animation/<model>/<line index>/``. Afterwards, one render per
    sample is copied into ``/extracted_animation/<model>/``.
    """
    model_name = 'Qwen2.5-Coder-7B-Instruct'
    source_fp = '/gen_code_manim_{0}.jsonl'.format(model_name)
    output_base = '/outputs/animation'
    animation_dir = os.path.join(output_base, model_name)
    os.makedirs(animation_dir, exist_ok=True)

    success_cnt = 0
    # Explicit UTF-8: records may contain non-ASCII text, and the platform
    # default encoding is not guaranteed to be UTF-8.
    with open(source_fp, 'r', encoding='utf-8') as f_in:
        for line_idx, line in tqdm(enumerate(f_in)):
            if not line.strip():
                # Tolerate blank/trailing lines; indices stay equal to line numbers.
                continue
            record = json.loads(line)
            code_cleaned = remove_think_process(record['code'])
            code_extracted, _ = extract_python_code_block(code_cleaned)
            sample_dir = os.path.join(animation_dir, str(line_idx))
            os.makedirs(sample_dir, exist_ok=True)
            success, _ = validate_manim_code(code_extracted, sample_dir)
            if success:
                success_cnt += 1

    print('Success count:', success_cnt)

    extract_animation_base = "/extracted_animation/"
    extract_animation_dir = os.path.join(extract_animation_base, model_name)
    os.makedirs(extract_animation_dir, exist_ok=True)
    collect_outputs(animation_dir, extract_animation_dir)


if __name__ == "__main__":
    main()